{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 客服对话多标签生成\n",
    "\n",
    "<span style=\"font-size: 20px; font-weight: bold;\">注意：您使用该案例默认的数据和模型训练时，会产生一定费用。计费方式参考：https://cloud.baidu.com/doc/WENXINWORKSHOP/s/6lrk4bgxb</span>\n",
    "\n",
    "在客服对话场景中，可以通过大模型分析用户与客服之间的对话信息，准确识别用户的意图和对应原因，生成对应标签为后续回复和营销策略服务。比如，用户因地址填写错误与客服沟通，说希望能够取消订单，则取消订单是意图，地址填写错误是原因。最初我们选择小模型进行多标签生成，在初期使用中展现出一定的效果，能够在较短时间内进行部署并提供基础的标签生成功能。然而，随着业务需求的日益复杂，现有的小模型在多标签生成上存在一些明显的问题和挑战：\n",
    "\n",
    "\n",
    "* 标签准确率不高：小模型的标签准确率通常保持在接近80%左右，无法满足业务进一步提高准确率的期望。业务需求日益复杂，用户提出的问题多样化，小模型的识别能力有限，导致部分标签无法准确标识。\n",
    "* 标注数据需求量大：训练一个有效的小模型，每个标签至少需要300个标注数据，人工成本高昂。尤其是在业务需求不断变化的情况下，标注工作量进一步增加。\n",
    "* 对标签体系的依赖性强：小模型对标签体系还会有较强的依赖，一旦业务标签体系发生较大变化，例如三层标签扩展到四层标签或大规模调整标签结构，标注和训练工作需要大规模重复建设。\n",
    "\n",
    "\n",
    "针对上述问题，我们提出了使用ERNIE Tiny大模型进行微调的解决方案。ERNIE Tiny大模型在语言理解能力上更强，能够在较少标注数据的情况下，达到或超过现有小模型的准确率要求。通过微调训练，ERNIE Tiny大模型能够更好地适应业务需求的变化，提高多标签生成的准确性和效率，从而更好地支持客服对话场景的应用。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 0. 环境准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qianfan import ChatCompletion\n",
    "from qianfan.dataset import Dataset\n",
    "from qianfan.common import Prompt\n",
    "from qianfan.trainer import LLMFinetune\n",
    "from qianfan.trainer.consts import PeftType\n",
    "from qianfan.trainer.configs import TrainConfig\n",
    "import os\n",
    "from qianfan.dataset import Dataset\n",
    "from qianfan.dataset.data_source.base import FormatType"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"QIANFAN_ACCESS_KEY\"] = \"your_access_key\"\n",
    "os.environ[\"QIANFAN_SECRET_KEY\"] = \"your_secret_key\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. 基座模型效果示例"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先，我们选择了ERNIE-Tiny-8K模型作为本次实验的基座模型。\n",
    "\n",
    "此处设置两个案例，带您直观了解微调前模型的输出问题"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat = ChatCompletion(model=\"ERNIE-Tiny-8K\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 例一"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "根据对话内容，最有可能的客户意图和对应的原因标签是：\n",
      "“客户意图”：订单取消\n",
      "“原因标签”：解决问题\n",
      "\n",
      "解释：客户李先生表示地址写错了需要取消订单，王琳K表示可以帮他取消订单并提供了相应的帮助。因此，可以判断客户的意图是订单取消。\n"
     ]
    }
   ],
   "source": [
    "target ={ \n",
    "    \"conversation\": (\n",
    "        \"王琳K：欢迎光临DianCan披萨，为了给您提供更加优质的服务，请问您有什么具体的问题或需要帮助吗？\"\n",
    "        \"客户李先生：我刚刚下了一个订单，但是地址写错了，能帮我取消吗？\"\n",
    "        \"王琳K：非常抱歉给您带来困扰，我可以帮您取消订单。为了确认您的身份，需要您提供订单号或者下单时使用的电话号码，可以吗？\"\n",
    "        \"客户李先生：我的订单号是DC123456789，电话号码是138****1234。\"\n",
    "        \"王琳K：非常感谢您提供的信息，我已经找到了您的订单。现在我将为您取消该订单，请稍等片刻。\"\n",
    "        \"客户李先生：好的，谢谢。\"\n",
    "        \"王琳K：您的订单已经成功取消。如有其他问题，请随时联系我们。感谢您的理解和支持。\"\n",
    "        \"客户李先生：非常感谢你们的帮助，我会重新下单的。\"\n",
    "        \"王琳K：非常高兴能够帮助您解决问题。祝您用餐愉快！如有其他问题，请随时联系我们。\"\n",
    ")\n",
    "}\n",
    "prompt = Prompt(\"\"\"你是一个对话意图识别并打标签的机器人，根据下面的已知信息，打上标签。\n",
    "                请使用以下格式输出：{\"意图\": xxx\n",
    "                                \"原因\": xxx}\n",
    "请根据以下会话的内容，精准判断最有可能的客户意图以及对应的原因标签，意图和原因标签必须严格控制在给定的范围之内。：\n",
    "\n",
    "{conversation}\n",
    "\"\"\")\n",
    "\n",
    "resp = chat.do(messages=[{\"role\": \"user\", \"content\": prompt.render(**target)[0]}])\n",
    "\n",
    "print(resp[\"result\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 例二"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "```json\n",
      "{\n",
      "\"意图\": \"客户服务\",\n",
      "\"原因\": \"顾客表达想要取消订单，处理地址错误的情况，机器人回应询问订单号和手机号码后四位以确认身份，并提供了处理订单取消及退款的具体措施\"\n",
      "}\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "target ={ \n",
    "    \"conversation\": (\n",
    "        \"DianCan 自助点餐机器人：您好，欢迎光临DianCan披萨，有什么可以为您服务的吗？\"\n",
    "        \"顾客：你好，我刚刚下了一个订单，但是我发现我填写的地址是错的。\"\n",
    "        \"DianCan 自助点餐机器人：非常抱歉给您带来困扰。请问您是否希望取消订单并重新下单呢？\"\n",
    "        \"顾客：是的，我想取消订单。\"\n",
    "        \"DianCan 自助点餐机器人：好的，请告诉我您的订单号，我会尽快帮您处理。\"\n",
    "        \"顾客：我的订单号是XXXX。\"\n",
    "        \"DianCan 自助点餐机器人：好的，已经为您查询到了订单。为了确认您的身份，请问您能提供下单时使用的手机号码后四位吗？\"\n",
    "        \"顾客：手机号码后四位是XXXX。\"\n",
    "        \"DianCan 自助点餐机器人：非常感谢，已经确认您的身份。我们现在就为您取消订单，并会尽快处理退款。退款将在3-7个工作日内原路返回至您的支付账户。请您注意查收。\"\n",
    "        \"顾客：好的，非常感谢你们的帮助。\"\n",
    "        \"DianCan 自助点餐机器人：不客气，如果您还有其他问题或需要进一步的帮助，请随时与我们联系。祝您用餐愉快！\"\n",
    ")\n",
    "}\n",
    "\n",
    "resp = chat.do(messages=[{\"role\": \"user\", \"content\": prompt.render(**target)[0]}])\n",
    "\n",
    "print(resp[\"result\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从上述两组例子中，我们可以总结出以下问题：\n",
    "\n",
    "\n",
    "* 问题一：微调前模型的输出可能并不能完全遵循指定格式进行输出。\n",
    "* 问题二：基座模型的输出无法精准识别客户意图和原因。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. 模型精调数据准备"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.1 构造“意图-原因”标签集"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先，将收集到的数据集进行标签注解。\n",
    "\n",
    "根据客服对话的内容，提炼客户意见与相应的发起客服对话原因，进行数据编号的标注，如下为例："
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| 意见                       | 原因                          | 数据编号 |\n",
    "|----------------------------|-------------------------------|----------|\n",
    "| 如何提交评价               | 未收到评价请求               | 1        |\n",
    "| 你们有某种餐品吗          | 餐品缺货、短期或长缺         | 1        |\n",
    "| 某餐需要做成不辣          | 顾客特殊需求                 | 1        |\n",
    "| 在餐厅怎么买东西怎么回    | 找回遗失物品                 | 1        |\n",
    "| 订单什么时候能做好        | 餐品制作时间                 | 1        |\n",
    "| 如何访问我的订单历史记录 | 订单历史订单详情            | 1        |\n",
    "| 取消订单                  | 取消订单_具体原因           | 1        |\n",
    "| 为什么我的优惠券没见了    | 优惠券未到期                 | 1        |\n",
    "| 取消订单                  | 地址填写错误                 | 1        |\n",
    "| 我想要把单品或套餐添加/删除 | j1或添加单品_质量定制       | 2        |\n",
    "| 餐品不对                  | 源错误                       | 1        |\n",
    "| 餐品配送太少              | 食物波动                     | 1        |\n",
    "| 我在餐厅可以参与活动吗    | 活动相关咨询                 | 1        |\n",
    "| 有关于推荐的商品          | 需要推荐咨询                 | 1        |\n",
    "| 我在餐厅开发票            | 开发票                       | 2        |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 生成对话数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "根据上述案例，从”意图-原因“生成客服对话的Prompt如下所示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "conversation_prompt =\"\"\"你是一个利用【意图-原因】生成对话的机器人，请你仔细观察下面的输入输出，发挥你的想象，根据【输入】提供的已知信息和要求，生成客服对话。具体要求如下：\n",
    "1.【输入】和【输出】的格式与示例相同。\n",
    "2.【输出】为生成的对话。\n",
    "3.你生成的是DianCan披萨公司的客服对话，请你刻意回避百胜公司、肯德基餐饮相关的产品和名词。\n",
    "遵循以上准则，请你根据【输入】创造一个新的客服对话，示例如下：\n",
    "\n",
    "【输入】\n",
    "###意图-原因\n",
    "\"{\"意图\": \"某产品你们有吗\"\n",
    "\"原因\": \"产品是否有_餐厅断货_临时or永久\"}\"\n",
    "【输出】\n",
    "### 对话内容\n",
    "\"客服:正在为您转接人工服务中，目前人工繁忙，如需继续等待请输入：继续\n",
    "客服:欢迎进入人工客服通道，56959很高兴为您服务对话过程中以及完成后，您会收到评价提醒，希望您能对我个人本次服务做个评价，您的反馈和建议也是我努力的方向哦，感谢~\n",
    "顾客:继续\n",
    "顾客:你好，鸡腿饭现在是下架了嘛\n",
    "客服:DianCan客服中心，很高兴为您服务，我先查看一下您反馈的问题哦~\n",
    "顾客:附近每家店子都没有[嚎哭]\n",
    "客服:没有看到说明售罄，暂时断货，没有这个餐点，建议客官过段时间在购买查看的，不好意思。\n",
    "客服:以您在我们官网看到的为准，有就是有，没有就是没有的呢。\n",
    "顾客:那可以查一下附近哪家店有吗\n",
    "顾客:武汉洪山区哪家店有\n",
    "顾客:我看了好多家都没有[嚎哭]\n",
    "客服:小二这边是客服中心的，不是某家餐厅，不是很清楚每个门店的具体情况，非常抱歉。\n",
    "客服:或者您可以通过DianCan微信公众号-自助服务-点击入群，可加入附近DianCan餐厅的社群哦~了解具体信息。\n",
    "客服:亲亲，您还在线吗？我还在快马加鞭处理中，如果您还有问题，也及时回应哦~\n",
    "客服:\"\n",
    "---------------------------------\n",
    "【输入】    \n",
    "###意图-原因\n",
    "\"{\"意图\": \"你们有某种餐品吗\"\n",
    "\"原因\": \"餐品断货_短期或长期\"}\"\n",
    "【输出】：\"\"\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "客服：您好，目前正在为您查询，关于您提到的餐品，我们暂时没有收到关于断货或下架的信息。建议您可以通过我们的官网或者微信公众号了解最新产品信息。如果您需要其他帮助或有其他问题，请随时告知，我们会尽快为您处理。\n",
      "\n",
      "如果您对餐品的供应情况有疑问或需要了解更多信息，建议您通过我们的官方渠道查询最新消息。如果您还有其他问题或需求，请随时告知，我们会竭诚为您服务。\n"
     ]
    }
   ],
   "source": [
    "resp = chat.do(messages=[{\"role\": \"user\", \"content\": conversation_prompt}])\n",
    "print(resp[\"result\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. SFT调优示例"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.1 数据集导入"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在完成上述的数据集准备工作后，我们可以开始进行模型微调训练。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先从平台中获取微调用的训练集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO][2024-08-12 16:47:47.213] dataset.py:430 [t:8570851136]: no data source was provided, construct\n",
      "[INFO][2024-08-12 16:47:47.214] dataset.py:282 [t:8570851136]: construct a qianfan data source from existed id: ds-scm8g98a7pv3zzf3, with args: {'format': <FormatType.Jsonl: 'jsonl'>}\n",
      "[INFO][2024-08-12 16:47:47.989] dataset_utils.py:317 [t:8570851136]: list qianfan dataset data by 0\n",
      "[INFO][2024-08-12 16:47:48.404] dataset_utils.py:339 [t:8570851136]: received dataset list from qianfan dataset\n",
      "[INFO][2024-08-12 16:47:48.405] dataset_utils.py:347 [t:8570851136]: retrieve single entity from https://easydata.bj.bcebos.com/_system_/dataset/ds-scm8g98a7pv3zzf3/texts/data/raw_aca7da8ef71c956315d9a7dc3874a4d5a65280bb382263c39aa16482de1b666e_91dc9885731b405bb32a5d5734c4dd5f?authorization=bce-auth-v1%2F50c8bb753dcb4e1d8646bb1ffefd3503%2F2024-08-12T08%3A47%3A48Z%2F7200%2Fhost%2F4042ce5b329d3026629877a2108adae8293bc615d5b4cc9c76948f9a2a917147 in try 0\n",
      "[INFO][2024-08-12 16:47:48.631] dataset_utils.py:361 [t:8570851136]: retrieve single entity from https://easydata.bj.bcebos.com/_system_/dataset/ds-scm8g98a7pv3zzf3/texts/data/raw_aca7da8ef71c956315d9a7dc3874a4d5a65280bb382263c39aa16482de1b666e_91dc9885731b405bb32a5d5734c4dd5f?authorization=bce-auth-v1%2F50c8bb753dcb4e1d8646bb1ffefd3503%2F2024-08-12T08%3A47%3A48Z%2F7200%2Fhost%2F4042ce5b329d3026629877a2108adae8293bc615d5b4cc9c76948f9a2a917147 succeeded, with content: [{\"prompt\": \"假设你有一套客户意图分类以及该分类下属的原因标签。请根据给定的客服对话内容，判断最有可能的客户意图以及对应的原因标签，意图和原因标签需要严格控制给定的范围之内；一个意图可能对应多个原因，但一个原因只会对应一个意图；如果均不匹配则回答无明确客户意图；回答请使用json的格式，示例：'{\\\"意图\\\": \\\"xxx\\\",\\\"原因”: \\\"xxx\\\"}'\\n### 下面是客户意图的分类\\n1.客户意图：如何提交评价;原因标签：未收到评价邀请\\n2.客户意图：你们有某种餐品吗;原因标签：餐品缺货_短期或长期\\n3.客户意图：某餐品需要做成不辣;原因标签：顾客特需服务\\n4.客户意图：在餐厅丢失了物品怎么寻回;原因标签：找回遗失物品\\n5.客户意图：订单什么时候能做好;原因标签：餐品制作时间\\n6.客户意图：如何访问我的订单历史记录;原因标签：历史订单查询\\n7.客户意图：取消订单;原因标签：取消订单_无具体理由\\n8.客户意图：为什么我的优惠券不见了;原因标签：优惠券未到账\\n9.客户意图：取消订单;原因标签：地址填写错误\\n10.客户意图：我想要加番茄酱或者不加番茄酱;原因标签：加or不加番茄酱_顾客定制\\n11.客户意图：餐品不对;原因标签：漏餐错餐\\n12.客户意图：餐厅电话是多少;原因标签：食物变质\\n13.客户意图：我在哪里可以参加活动;原因标签：活动地点咨询\\n14.客户意图：有没有推荐的产品;原因标签：需要推荐餐品\\n15.客户意图：客户要开发票;原因标签：开发票\\n\\n###对话内容\\n李星辰DY    2023年07月19日 10:23:48\\n您好，欢迎光临DianCan披萨，有什么可以为您服务的吗？\\n用户673210    2023年07月19日 10:24:12\\n我想开一下发票\\n李星辰DY    2023年07月19日 10:24:35\\n当然可以，请您提供一下订单号和开票信息，我们会尽快为您处理。\\n用户673210    2023年07月19日 10:25:01\\n订单号是DC230715001，开票信息是公司名称：XX科技有限公司，税号：9132XXXXXXXXX\\n李星辰DY    2023年07月19日 10:25:38\\n好的，已经收到您的订单号和开票信息，我们会尽快为您开具发票并发送到您的邮箱。请问您的邮箱地址是什么？\\n用户673210    2023年07月19日 10:26:05\\n我的邮箱是[example@example.com](mailto:example@example.com)\\n李星辰DY    2023年07月19日 10:26:32\\n非常感谢，我们已经记录下了您的邮箱地址。发票将在24小时内发送到您的邮箱，请注意查收。\\n用户673210    2023年07月19日 10:27:00\\n好的，谢谢！\\n李星辰DY    2023年07月19日 10:27:25\\n不客气，如果您有任何其他问题或需要进一步的帮助，请随时联系我们。祝您用餐愉快！\\n\\n###输出\", \"response\": [[\"{'意图': '客户要开发票', '原因': '开发票'}\"]]}]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'entity_id': 'aca7da8ef71c956315d9a7dc3874a4d5a65280bb382263c39aa16482de1b666e_91dc9885731b405bb32a5d5734c4dd5f', 'entity_content': '[{\"prompt\": \"假设你有一套客户意图分类以及该分类下属的原因标签。请根据给定的客服对话内容，判断最有可能的客户意图以及对应的原因标签，意图和原因标签需要严格控制给定的范围之内；一个意图可能对应多个原因，但一个原因只会对应一个意图；如果均不匹配则回答无明确客户意图；回答请使用json的格式，示例：\\'{\\\\\"意图\\\\\": \\\\\"xxx\\\\\",\\\\\"原因”: \\\\\"xxx\\\\\"}\\'\\\\n### 下面是客户意图的分类\\\\n1.客户意图：如何提交评价;原因标签：未收到评价邀请\\\\n2.客户意图：你们有某种餐品吗;原因标签：餐品缺货_短期或长期\\\\n3.客户意图：某餐品需要做成不辣;原因标签：顾客特需服务\\\\n4.客户意图：在餐厅丢失了物品怎么寻回;原因标签：找回遗失物品\\\\n5.客户意图：订单什么时候能做好;原因标签：餐品制作时间\\\\n6.客户意图：如何访问我的订单历史记录;原因标签：历史订单查询\\\\n7.客户意图：取消订单;原因标签：取消订单_无具体理由\\\\n8.客户意图：为什么我的优惠券不见了;原因标签：优惠券未到账\\\\n9.客户意图：取消订单;原因标签：地址填写错误\\\\n10.客户意图：我想要加番茄酱或者不加番茄酱;原因标签：加or不加番茄酱_顾客定制\\\\n11.客户意图：餐品不对;原因标签：漏餐错餐\\\\n12.客户意图：餐厅电话是多少;原因标签：食物变质\\\\n13.客户意图：我在哪里可以参加活动;原因标签：活动地点咨询\\\\n14.客户意图：有没有推荐的产品;原因标签：需要推荐餐品\\\\n15.客户意图：客户要开发票;原因标签：开发票\\\\n\\\\n###对话内容\\\\n李星辰DY    2023年07月19日 10:23:48\\\\n您好，欢迎光临DianCan披萨，有什么可以为您服务的吗？\\\\n用户673210    2023年07月19日 10:24:12\\\\n我想开一下发票\\\\n李星辰DY    2023年07月19日 10:24:35\\\\n当然可以，请您提供一下订单号和开票信息，我们会尽快为您处理。\\\\n用户673210    2023年07月19日 10:25:01\\\\n订单号是DC230715001，开票信息是公司名称：XX科技有限公司，税号：9132XXXXXXXXX\\\\n李星辰DY    2023年07月19日 10:25:38\\\\n好的，已经收到您的订单号和开票信息，我们会尽快为您开具发票并发送到您的邮箱。请问您的邮箱地址是什么？\\\\n用户673210    2023年07月19日 10:26:05\\\\n我的邮箱是[example@example.com](mailto:example@example.com)\\\\n李星辰DY    2023年07月19日 10:26:32\\\\n非常感谢，我们已经记录下了您的邮箱地址。发票将在24小时内发送到您的邮箱，请注意查收。\\\\n用户673210    2023年07月19日 10:27:00\\\\n好的，谢谢！\\\\n李星辰DY    2023年07月19日 10:27:25\\\\n不客气，如果您有任何其他问题或需要进一步的帮助，请随时联系我们。祝您用餐愉快！\\\\n\\\\n###输出\", \"response\": [[\"{\\'意图\\': \\'客户要开发票\\', \\'原因\\': \\'开发票\\'}\"]]}]'}]\n"
     ]
    }
   ],
   "source": [
    "ds = Dataset.load(qianfan_dataset_id = \"ds-scm8g98a7pv3zzf3\", format = FormatType.Jsonl)\n",
    "print(ds[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.2 微调训练\n",
    "\n",
    "拿到一个训练场景或者任务后，往往比较难判断参数应该如何调整。一般使用默认的参数值进行训练即可，平台中的默认参数是多次实验的经验结晶。 接下来介绍参数配置中有两个较为关键的参数：\n",
    "\n",
    "* 迭代轮次（Epoch）: 控制训练过程中的迭代轮数。轮数增加代表会使用训练集对模型训练一次。\n",
    "\n",
    "* 学习率（Learning Rate）: 是在梯度下降的过程中更新权重时的超参数，过高会导致模型难以收敛，过低则会导致模型收敛速度过慢，平台已给出默认推荐值，也可根据经验调整。\n",
    "\n",
    "* 序列长度：如果对话数据的长度较短，建议选择短的序列长度，可以提升训练的速度。\n",
    "\n",
    "本次也针对Epoch和Learning Rate进行简要的调参实验，详细实验结果可以看效果评估数据。\n",
    "\n",
    "如果您是模型训练的专家，千帆也提供了训练更多的高级参数供您选择。这里也建议您初期调参时步长可以设定稍大些，因为较小的超参变动对模型效果的影响小，会被随机波动掩盖。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建trainer任务"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = LLMFinetune(\n",
    "    name = \"dialogue-multi-tag\",\n",
    "    train_type=\"ERNIE-Tiny-8K\",\n",
    "    train_config=TrainConfig(\n",
    "        epoch=1,\n",
    "        learning_rate=1e-5,\n",
    "        peft_type=PeftType.ALL,\n",
    "    ),\n",
    "    dataset=ds\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "启动训练任务"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO][2024-08-12 16:51:44.361] base.py:226 [t:8570851136]: trainer subprocess started, pid: 11416\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[None]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO][2024-08-12 16:51:44.368] base.py:202 [t:8570851136]: check running log in .qianfan_exec_cache/ZdTRr7iw/2024-08-12.log\n"
     ]
    }
   ],
   "source": [
    "trainer.start()\n",
    "print(trainer.result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO][2024-08-12 17:23:18.156] dataset.py:430 [t:8570851136]: no data source was provided, construct\n",
      "[INFO][2024-08-12 17:23:18.157] dataset.py:282 [t:8570851136]: construct a qianfan data source from existed id: ds-scm8g98a7pv3zzf3, with args: {}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'datasets': {'versions': [{'versionId': 'ds-scm8g98a7pv3zzf3'}],\n",
       "  'sourceType': 'Platform',\n",
       "  'splitRatio': 20},\n",
       " 'task_id': 'task-85tcmxg0try3',\n",
       " 'job_id': 'job-xhk1gtuvvdbh',\n",
       " 'metrics': {'BLEU-4': '99.58%',\n",
       "  'ROUGE-1': '99.62%',\n",
       "  'ROUGE-2': '99.60%',\n",
       "  'ROUGE-L': '99.75%',\n",
       "  'EDIT-DISTANCE': '0.11',\n",
       "  'EMBEDDING-DISTANCE': '0.00'},\n",
       " 'checkpoints': [],\n",
       " 'model_set_id': 'am-mgf6icsebsa4',\n",
       " 'model_id': 'amv-8t6qf24m4xcb'}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.3 结果评估"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这一部分中，我们可以对刚才训练好的模型进行评估，评估模型的微调效果。\n",
    "\n",
    "针对客服对话意图识别场景的评估任务，我们需要制定如下规则与方法：\n",
    "\n",
    "* 评估规则：在客服对话多标签生成中，结果相对清晰明确。可定义评分规则，评分按照1-2-3三档执行，其中1分表示输出格式不对；2分表示输出格式正确，但内容不正确；3分表示输出的格式和内容全部正确。\n",
    "\n",
    "\n",
    "* 评估方式：实施自动打分，将评分标准、场景要求、上下文、大模型的回答拼接成一个Prompt，使用EB4等大模型进行自动打分。\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先，导入训练好的模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[TRACE][2024-08-14 16:10:55.096] base.py:175 [t:8570851136]: raw request: QfRequest(method='POST', url='https://qianfan.baidubce.com/wenxinworkshop/modelrepo/modelVersionDetail', query={}, headers={'Content-Type': 'application/json', 'Host': 'qianfan.baidubce.com', 'request-source': 'qianfan_py_sdk_v0.4.5', 'x-bce-date': '2024-08-14T08:10:55Z', 'Authorization': 'bce-auth-v1/2d9f701d872f4f54b69274e9a17ff5b2/2024-08-14T08:10:55Z/300/x-bce-date;content-type;request-source;host/94f6326ddcd96fb90a8bfb86ca08d709607160be6334525bdbf77cfcbd6230e3'}, json_body={'modelVersionId': 'amv-8t6qf24m4xcb'}, files={}, retry_config=RetryConfig(retry_count=1, timeout=60, max_wait_interval=120.0, backoff_factor=0, jitter=1.0, retry_err_codes={500000, 18, 336100}))\n"
     ]
    }
   ],
   "source": [
    "from qianfan.model import Model\n",
    "\n",
    "# 从`version_id`构造模型：\n",
    "m = Model(id='amv-8t6qf24m4xcb')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "导入相应训练集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO][2024-08-14 16:43:13.466] dataset.py:430 [t:8570851136]: no data source was provided, construct\n",
      "[INFO][2024-08-14 16:43:13.466] dataset.py:282 [t:8570851136]: construct a qianfan data source from existed id: ds-n1dg1czx3ciqrakr, with args: {'input_columns': ['prompt'], 'reference_column': 'response'}\n",
      "[TRACE][2024-08-14 16:43:13.467] base.py:175 [t:8570851136]: raw request: QfRequest(method='POST', url='https://qianfan.baidubce.com/wenxinworkshop/dataset/info', query={}, headers={'Content-Type': 'application/json', 'Host': 'qianfan.baidubce.com', 'request-source': 'qianfan_py_sdk_v0.4.5', 'x-bce-date': '2024-08-14T08:43:13Z', 'Authorization': 'bce-auth-v1/2d9f701d872f4f54b69274e9a17ff5b2/2024-08-14T08:43:13Z/300/x-bce-date;content-type;request-source;host/c295693a48c205dd2be180918b1ec4c2a58fc373a2b296a57913ccd714f1b479'}, json_body={'datasetId': 'ds-n1dg1czx3ciqrakr'}, files={}, retry_config=RetryConfig(retry_count=1, timeout=60, max_wait_interval=120.0, backoff_factor=0, jitter=1.0, retry_err_codes={500000, 18, 336100}))\n"
     ]
    }
   ],
   "source": [
    "eval_ds = Dataset.load(qianfan_dataset_id =\"ds-n1dg1czx3ciqrakr\",organize_data_as_group=False, input_columns=[\"prompt\"], reference_column=\"response\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建本次评估任务\n",
    "\n",
    "\n",
    "本次评估采用了在线裁判员评估打分方式，因此使用了QianfanRefereeEvaluator作为本次的评估器。\n",
    "\n",
    "其中涉及到了以下参数：\n",
    "\n",
    "* prompt_steps:填写评估的打分方式与方法\n",
    "* prompt_metrics:分数最终呈现的形式。默认为“综合得分”\n",
    "* prompt_max_score:评估的最大分数。默认为3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qianfan.evaluation.evaluator import QianfanRefereeEvaluator, QianfanRuleEvaluator\n",
    "from qianfan.evaluation.consts import QianfanRefereeEvaluatorDefaultMetrics\n",
    "\n",
    "your_app_id = 105835511\n",
    "\n",
    "qianfan_evaluators = [\n",
    "    QianfanRefereeEvaluator(\n",
    "        prompt_steps= \"\"\"\n",
    "请你对模型的输出进行评分，阅读输出的格式与内容。评分按照1-2-3三档执行，其中1分表示输出格式不对；2分表示输出格式正确，但内容不正确；3分表示输出的格式和内容全部正确。\"\"\",\n",
    "        app_id=your_app_id,\n",
    "        prompt_metrics=QianfanRefereeEvaluatorDefaultMetrics,\n",
    "        prompt_max_score=3,\n",
    "    )\n",
    "    \n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "执行评估任务"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qianfan.evaluation import EvaluationManager\n",
    "\n",
    "em = EvaluationManager(qianfan_evaluators=qianfan_evaluators)\n",
    "result = em.eval([m], eval_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'m_xNDPALyMT4db_1': {'accuracy': 0, 'f1Score': 0, 'rouge_1': 0, 'rouge_2': 0, 'rouge_l': 0, 'bleu4': 0, 'avgJudgeScore': 2.951613, 'stdJudgeScore': 0.3779153, 'medianJudgeScore': 3, 'scoreDistribution': {'-1': 0, '0': 1, '1': 0, '2': 0, '3': 61}, 'manualAvgScore': 0, 'goodCaseProportion': 0, 'subjectiveImpression': '', 'manualScoreDistribution': None}}\n"
     ]
    }
   ],
   "source": [
    "print(result.metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为了体现微调的提升效果，在此对对未作`精调前`的模型进行评估："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'ERNIE Tiny_ERNIE-Tiny-8K': {'accuracy': 0, 'f1Score': 0, 'rouge_1': 0, 'rouge_2': 0, 'rouge_l': 0, 'bleu4': 0, 'avgJudgeScore': 1.3548387, 'stdJudgeScore': 0.9688166, 'medianJudgeScore': 1, 'scoreDistribution': {'-1': 0, '0': 14, '1': 20, '2': 20, '3': 8}, 'manualAvgScore': 0, 'goodCaseProportion': 0, 'subjectiveImpression': '', 'manualScoreDistribution': None}}\n"
     ]
    }
   ],
   "source": [
    "raw_model = Model(id = \"amv-sb5kfqie51z1\")\n",
    "raw_result = em.eval([raw_model], eval_ds)\n",
    "print(raw_result.metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "根据结果可以看到：基于ERNIE-Tiny-8K模型使用SFT全量更新的训练方法，且训练参数Epoch=1、Learning Rate=1e-5、序列长度=4096时，模型输出效果远远优于微调之前。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| 模型                        | 训练方法 | Epoch | Learning Rate | 序列长度 | 评估结果-0分 | 评估结果-1分 | 评估结果-2分 | 评估结果-3分 | 评估结果  |\n",
    "|-----------------------------|----------|-------|---------------|----------|--------|------|--------------|--------------|------------|\n",
    "| ERNIE-Tiny-8K (微调前)      | --       | --    | --            | --   |   14    | 20           | 20           | 8           | 0.9688166   |\n",
    "| ERNIE-Tiny-8K (微调后)      | 全量更新 | 1     | 1e-5          | 4096     |   0   | 1            | 0            | 61           | 2.951613   |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
